

<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
  <meta charset="utf-8">
  
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  
  <title>Pruning Unpromising Trials &mdash; Optuna 1.4.0 documentation</title>
  

  
  
    <link rel="shortcut icon" href="../_static/favicon.ico"/>
  
  
  

  
  <script type="text/javascript" src="../_static/js/modernizr.min.js"></script>
  
    
      <script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
        <script src="../_static/jquery.js"></script>
        <script src="../_static/underscore.js"></script>
        <script src="../_static/doctools.js"></script>
        <script src="../_static/language_data.js"></script>
        <script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
    
    <script type="text/javascript" src="../_static/js/theme.js"></script>

    

  
  <link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
  <link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
  <link rel="stylesheet" href="../_static/css/custom.css" type="text/css" />
    <link rel="index" title="Index" href="../genindex.html" />
    <link rel="search" title="Search" href="../search.html" />
    <link rel="next" title="User-Defined Sampler" href="sampler.html" />
    <link rel="prev" title="User Attributes" href="attributes.html" /> 
</head>

<body class="wy-body-for-nav">

   
  <div class="wy-grid-for-nav">
    
    <nav data-toggle="wy-nav-shift" class="wy-nav-side">
      <div class="wy-side-scroll">
        <div class="wy-side-nav-search" >
          

          
            <a href="../index.html" class="icon icon-home"> Optuna
          

          
          </a>

          
            
            
              <div class="version">
                1.4.0
              </div>
            
          

          
<div role="search">
  <form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
    <input type="text" name="q" placeholder="Search docs" />
    <input type="hidden" name="check_keywords" value="yes" />
    <input type="hidden" name="area" value="default" />
  </form>
</div>

          
        </div>

        <div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
          
            
            
              
            
            
              <p class="caption"><span class="caption-text">Contents</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../installation.html">Installation</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="index.html">Tutorial</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="first.html">First Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="configurations.html">Advanced Configurations</a></li>
<li class="toctree-l2"><a class="reference internal" href="rdb.html">Saving/Resuming a Study with an RDB Backend</a></li>
<li class="toctree-l2"><a class="reference internal" href="distributed.html">Distributed Optimization</a></li>
<li class="toctree-l2"><a class="reference internal" href="cli.html">Command-Line Interface</a></li>
<li class="toctree-l2"><a class="reference internal" href="attributes.html">User Attributes</a></li>
<li class="toctree-l2 current"><a class="current reference internal" href="#">Pruning Unpromising Trials</a><ul>
<li class="toctree-l3"><a class="reference internal" href="#activating-pruners">Activating Pruners</a></li>
<li class="toctree-l3"><a class="reference internal" href="#integration-modules-for-pruning">Integration Modules for Pruning</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="sampler.html">User-Defined Sampler</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../reference/index.html">API Reference</a></li>
<li class="toctree-l1"><a class="reference internal" href="../faq.html">FAQ</a></li>
</ul>

            
          
        </div>
      </div>
    </nav>

    <section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">

      
      <nav class="wy-nav-top" aria-label="top navigation">
        
          <i data-toggle="wy-nav-top" class="fa fa-bars"></i>
          <a href="../index.html">Optuna</a>
        
      </nav>


      <div class="wy-nav-content">
        
        <div class="rst-content">
        
          















          <div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
           <div itemprop="articleBody">
            
  <div class="section" id="pruning-unpromising-trials">
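<span id="pruning"></span><h1>Pruning Unpromising Trials<a class="headerlink" href="#pruning-unpromising-trials" title="Permalink to this headline">¶</a></h1>
<p>This feature automatically stops unpromising trials at an early stage of training (also known as automated early stopping). Optuna provides interfaces for implementing pruning concisely in iterative training algorithms.</p>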
<div class="section" id="activating-pruners">
<h2>Activating Pruners<a class="headerlink" href="#activating-pruners" title="Permalink to this headline">¶</a></h2>
<p>To turn on pruning, call <a class="reference internal" href="../reference/trial.html#optuna.trial.Trial.report" title="optuna.trial.Trial.report"><code class="xref py py-func docutils literal notranslate"><span class="pre">report()</span></code></a> and <a class="reference internal" href="../reference/trial.html#optuna.trial.Trial.should_prune" title="optuna.trial.Trial.should_prune"><code class="xref py py-func docutils literal notranslate"><span class="pre">should_prune()</span></code></a> after each step of the iterative training. <a class="reference internal" href="../reference/trial.html#optuna.trial.Trial.report" title="optuna.trial.Trial.report"><code class="xref py py-func docutils literal notranslate"><span class="pre">report()</span></code></a> periodically reports the intermediate objective value so that it can be monitored during training. <a class="reference internal" href="../reference/trial.html#optuna.trial.Trial.should_prune" title="optuna.trial.Trial.should_prune"><code class="xref py py-func docutils literal notranslate"><span class="pre">should_prune()</span></code></a> decides, based on a predefined condition, whether the trial should be terminated.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="sd">&quot;&quot;&quot;filename: prune.py&quot;&quot;&quot;</span>

<span class="kn">import</span> <span class="nn">sklearn.datasets</span>
<span class="kn">import</span> <span class="nn">sklearn.linear_model</span>
<span class="kn">import</span> <span class="nn">sklearn.model_selection</span>

<span class="kn">import</span> <span class="nn">optuna</span>

<span class="k">def</span> <span class="nf">objective</span><span class="p">(</span><span class="n">trial</span><span class="p">):</span>
    <span class="n">iris</span> <span class="o">=</span> <span class="n">sklearn</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">load_iris</span><span class="p">()</span>
    <span class="n">classes</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">iris</span><span class="o">.</span><span class="n">target</span><span class="p">))</span>
    <span class="n">train_x</span><span class="p">,</span> <span class="n">valid_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">valid_y</span> <span class="o">=</span> \
        <span class="n">sklearn</span><span class="o">.</span><span class="n">model_selection</span><span class="o">.</span><span class="n">train_test_split</span><span class="p">(</span><span class="n">iris</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">iris</span><span class="o">.</span><span class="n">target</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="mf">0.25</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>

    <span class="n">alpha</span> <span class="o">=</span> <span class="n">trial</span><span class="o">.</span><span class="n">suggest_loguniform</span><span class="p">(</span><span class="s1">&#39;alpha&#39;</span><span class="p">,</span> <span class="mf">1e-5</span><span class="p">,</span> <span class="mf">1e-1</span><span class="p">)</span>
    <span class="n">clf</span> <span class="o">=</span> <span class="n">sklearn</span><span class="o">.</span><span class="n">linear_model</span><span class="o">.</span><span class="n">SGDClassifier</span><span class="p">(</span><span class="n">alpha</span><span class="o">=</span><span class="n">alpha</span><span class="p">)</span>

    <span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
        <span class="n">clf</span><span class="o">.</span><span class="n">partial_fit</span><span class="p">(</span><span class="n">train_x</span><span class="p">,</span> <span class="n">train_y</span><span class="p">,</span> <span class="n">classes</span><span class="o">=</span><span class="n">classes</span><span class="p">)</span>

        <span class="c1"># Report intermediate objective value.</span>
        <span class="n">intermediate_value</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">clf</span><span class="o">.</span><span class="n">score</span><span class="p">(</span><span class="n">valid_x</span><span class="p">,</span> <span class="n">valid_y</span><span class="p">)</span>
        <span class="n">trial</span><span class="o">.</span><span class="n">report</span><span class="p">(</span><span class="n">intermediate_value</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span>

        <span class="c1"># Handle pruning based on the intermediate value.</span>
        <span class="k">if</span> <span class="n">trial</span><span class="o">.</span><span class="n">should_prune</span><span class="p">():</span>
            <span class="k">raise</span> <span class="n">optuna</span><span class="o">.</span><span class="n">TrialPruned</span><span class="p">()</span>

    <span class="k">return</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">clf</span><span class="o">.</span><span class="n">score</span><span class="p">(</span><span class="n">valid_x</span><span class="p">,</span> <span class="n">valid_y</span><span class="p">)</span>

<span class="c1"># Set up the median stopping rule as the pruning condition.</span>
<span class="n">study</span> <span class="o">=</span> <span class="n">optuna</span><span class="o">.</span><span class="n">create_study</span><span class="p">(</span><span class="n">pruner</span><span class="o">=</span><span class="n">optuna</span><span class="o">.</span><span class="n">pruners</span><span class="o">.</span><span class="n">MedianPruner</span><span class="p">())</span>
<span class="n">study</span><span class="o">.</span><span class="n">optimize</span><span class="p">(</span><span class="n">objective</span><span class="p">,</span> <span class="n">n_trials</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span>
</pre></div>
</div>
<p>Run the script above:</p>
<div class="highlight-bash notranslate"><div class="highlight"><pre><span></span>$ python prune.py
<span class="o">[</span>I <span class="m">2020</span>-06-12 <span class="m">16</span>:54:23,876<span class="o">]</span> Trial <span class="m">0</span> finished with value: <span class="m">0</span>.3157894736842105 and parameters: <span class="o">{</span><span class="s1">&#39;alpha&#39;</span>: <span class="m">0</span>.00181467547181131<span class="o">}</span>. Best is trial <span class="m">0</span> with value: <span class="m">0</span>.3157894736842105.
<span class="o">[</span>I <span class="m">2020</span>-06-12 <span class="m">16</span>:54:23,981<span class="o">]</span> Trial <span class="m">1</span> finished with value: <span class="m">0</span>.07894736842105265 and parameters: <span class="o">{</span><span class="s1">&#39;alpha&#39;</span>: <span class="m">0</span>.015378744419287613<span class="o">}</span>. Best is trial <span class="m">1</span> with value: <span class="m">0</span>.07894736842105265.
<span class="o">[</span>I <span class="m">2020</span>-06-12 <span class="m">16</span>:54:24,083<span class="o">]</span> Trial <span class="m">2</span> finished with value: <span class="m">0</span>.21052631578947367 and parameters: <span class="o">{</span><span class="s1">&#39;alpha&#39;</span>: <span class="m">0</span>.04089428832878595<span class="o">}</span>. Best is trial <span class="m">1</span> with value: <span class="m">0</span>.07894736842105265.
<span class="o">[</span>I <span class="m">2020</span>-06-12 <span class="m">16</span>:54:24,185<span class="o">]</span> Trial <span class="m">3</span> finished with value: <span class="m">0</span>.052631578947368474 and parameters: <span class="o">{</span><span class="s1">&#39;alpha&#39;</span>: <span class="m">0</span>.004018735937374473<span class="o">}</span>. Best is trial <span class="m">3</span> with value: <span class="m">0</span>.052631578947368474.
<span class="o">[</span>I <span class="m">2020</span>-06-12 <span class="m">16</span>:54:24,303<span class="o">]</span> Trial <span class="m">4</span> finished with value: <span class="m">0</span>.07894736842105265 and parameters: <span class="o">{</span><span class="s1">&#39;alpha&#39;</span>: <span class="m">2</span>.805688697062864e-05<span class="o">}</span>. Best is trial <span class="m">3</span> with value: <span class="m">0</span>.052631578947368474.
<span class="o">[</span>I <span class="m">2020</span>-06-12 <span class="m">16</span>:54:24,315<span class="o">]</span> Trial <span class="m">5</span> pruned.
<span class="o">[</span>I <span class="m">2020</span>-06-12 <span class="m">16</span>:54:24,355<span class="o">]</span> Trial <span class="m">6</span> pruned.
<span class="o">[</span>I <span class="m">2020</span>-06-12 <span class="m">16</span>:54:24,511<span class="o">]</span> Trial <span class="m">7</span> finished with value: <span class="m">0</span>.052631578947368474 and parameters: <span class="o">{</span><span class="s1">&#39;alpha&#39;</span>: <span class="m">2</span>.243775785299103e-05<span class="o">}</span>. Best is trial <span class="m">3</span> with value: <span class="m">0</span>.052631578947368474.
<span class="o">[</span>I <span class="m">2020</span>-06-12 <span class="m">16</span>:54:24,625<span class="o">]</span> Trial <span class="m">8</span> finished with value: <span class="m">0</span>.1842105263157895 and parameters: <span class="o">{</span><span class="s1">&#39;alpha&#39;</span>: <span class="m">0</span>.007021209286214553<span class="o">}</span>. Best is trial <span class="m">3</span> with value: <span class="m">0</span>.052631578947368474.
<span class="o">[</span>I <span class="m">2020</span>-06-12 <span class="m">16</span>:54:24,629<span class="o">]</span> Trial <span class="m">9</span> pruned.
...
</pre></div>
</div>
<p>The output contains messages such as <code class="docutils literal notranslate"><span class="pre">Trial</span> <span class="pre">5</span> <span class="pre">pruned.</span></code> This means that those trials (5, 6, and 9 in the log above) were terminated before they finished all of their iterations.</p>
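<p>As a rough illustration of the kind of condition <code class="docutils literal notranslate"><span class="pre">should_prune()</span></code> evaluates, the sketch below implements a simplified median stopping rule in plain Python. It is not Optuna's implementation: the function name <code class="docutils literal notranslate"><span class="pre">median_rule_should_prune</span></code>, the layout of <code class="docutils literal notranslate"><span class="pre">history</span></code>, and the way <code class="docutils literal notranslate"><span class="pre">n_startup_trials</span></code> is counted are assumptions made for this example, and lower values are assumed to be better.</p>

```python
from statistics import median


def median_rule_should_prune(history, current_values, n_startup_trials=5):
    """Simplified median stopping rule (illustrative, not Optuna's code).

    history: list of lists; history[i][s] is the intermediate value that
        earlier trial i reported at step s (lower is better).
    current_values: intermediate values reported so far by the running trial.
    n_startup_trials: assumed knob; never prune until at least this many
        earlier trials have reported a value at the current step.
    """
    if not current_values:
        return False
    step = len(current_values) - 1
    # Only compare against earlier trials that reached the same step.
    peers = [values[step] for values in history if len(values) > step]
    if n_startup_trials > len(peers):
        return False
    # Prune when the best value seen so far is worse than the peers' median.
    return min(current_values) > median(peers)


# Five earlier trials, each with values for step 0 and step 1.
history = [[0.50, 0.30], [0.40, 0.20], [0.60, 0.35], [0.45, 0.25], [0.55, 0.32]]

print(median_rule_should_prune(history, [0.70]))      # True: 0.70 is above the step-0 median of 0.50
print(median_rule_should_prune(history, [0.42]))      # False: 0.42 is below that median
print(median_rule_should_prune(history[:3], [0.70]))  # False: too few earlier trials to compare against
```

<p>In this toy rule no trial can be pruned until five earlier trials have reported values, which is one plausible reading of why trials 0 to 4 in the log above all ran to completion.</p>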
</div>
<div class="section" id="integration-modules-for-pruning">
<h2>Integration Modules for Pruning<a class="headerlink" href="#integration-modules-for-pruning" title="Permalink to this headline">¶</a></h2>
<p>To make pruning easier to implement, Optuna provides integration modules for the following frameworks.</p>
<ul class="simple">
<li><p>XGBoost: <a class="reference internal" href="../reference/integration.html#optuna.integration.XGBoostPruningCallback" title="optuna.integration.XGBoostPruningCallback"><code class="xref py py-class docutils literal notranslate"><span class="pre">optuna.integration.XGBoostPruningCallback</span></code></a></p></li>
<li><p>LightGBM: <a class="reference internal" href="../reference/integration.html#optuna.integration.LightGBMPruningCallback" title="optuna.integration.LightGBMPruningCallback"><code class="xref py py-class docutils literal notranslate"><span class="pre">optuna.integration.LightGBMPruningCallback</span></code></a></p></li>
<li><p>Chainer: <a class="reference internal" href="../reference/integration.html#optuna.integration.ChainerPruningExtension" title="optuna.integration.ChainerPruningExtension"><code class="xref py py-class docutils literal notranslate"><span class="pre">optuna.integration.ChainerPruningExtension</span></code></a></p></li>
<li><p>Keras: <a class="reference internal" href="../reference/integration.html#optuna.integration.KerasPruningCallback" title="optuna.integration.KerasPruningCallback"><code class="xref py py-class docutils literal notranslate"><span class="pre">optuna.integration.KerasPruningCallback</span></code></a></p></li>
<li><p>TensorFlow: <a class="reference internal" href="../reference/integration.html#optuna.integration.TensorFlowPruningHook" title="optuna.integration.TensorFlowPruningHook"><code class="xref py py-class docutils literal notranslate"><span class="pre">optuna.integration.TensorFlowPruningHook</span></code></a></p></li>
<li><p>tf.keras: <a class="reference internal" href="../reference/integration.html#optuna.integration.TFKerasPruningCallback" title="optuna.integration.TFKerasPruningCallback"><code class="xref py py-class docutils literal notranslate"><span class="pre">optuna.integration.TFKerasPruningCallback</span></code></a></p></li>
<li><p>MXNet: <a class="reference internal" href="../reference/integration.html#optuna.integration.MXNetPruningCallback" title="optuna.integration.MXNetPruningCallback"><code class="xref py py-class docutils literal notranslate"><span class="pre">optuna.integration.MXNetPruningCallback</span></code></a></p></li>
<li><p>PyTorch Ignite: <a class="reference internal" href="../reference/integration.html#optuna.integration.PyTorchIgnitePruningHandler" title="optuna.integration.PyTorchIgnitePruningHandler"><code class="xref py py-class docutils literal notranslate"><span class="pre">optuna.integration.PyTorchIgnitePruningHandler</span></code></a></p></li>
<li><p>PyTorch Lightning: <a class="reference internal" href="../reference/integration.html#optuna.integration.PyTorchLightningPruningCallback" title="optuna.integration.PyTorchLightningPruningCallback"><code class="xref py py-class docutils literal notranslate"><span class="pre">optuna.integration.PyTorchLightningPruningCallback</span></code></a></p></li>
<li><p>FastAI: <a class="reference internal" href="../reference/integration.html#optuna.integration.FastAIPruningCallback" title="optuna.integration.FastAIPruningCallback"><code class="xref py py-class docutils literal notranslate"><span class="pre">optuna.integration.FastAIPruningCallback</span></code></a></p></li>
</ul>
<p>For example, <a class="reference internal" href="../reference/integration.html#optuna.integration.XGBoostPruningCallback" title="optuna.integration.XGBoostPruningCallback"><code class="xref py py-class docutils literal notranslate"><span class="pre">XGBoostPruningCallback</span></code></a> introduces pruning without requiring any change to the logic of the training iteration. (See the <a class="reference external" href="https://github.com/optuna/optuna/blob/master/examples/pruning/xgboost_integration.py">example</a> for the complete script.)</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Excerpt from inside objective(trial); xgb, param, dtrain and dvalid are set up in the linked example.</span>
<span class="c1"># &#39;validation-error&#39; combines the eval set name (&#39;validation&#39;) with the metric name (&#39;error&#39;).</span>
<span class="n">pruning_callback</span> <span class="o">=</span> <span class="n">optuna</span><span class="o">.</span><span class="n">integration</span><span class="o">.</span><span class="n">XGBoostPruningCallback</span><span class="p">(</span><span class="n">trial</span><span class="p">,</span> <span class="s1">&#39;validation-error&#39;</span><span class="p">)</span>
<span class="n">bst</span> <span class="o">=</span> <span class="n">xgb</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">dtrain</span><span class="p">,</span> <span class="n">evals</span><span class="o">=</span><span class="p">[(</span><span class="n">dvalid</span><span class="p">,</span> <span class="s1">&#39;validation&#39;</span><span class="p">)],</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">pruning_callback</span><span class="p">])</span>
</pre></div>
</div>
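<p>The callback pattern behind such integrations can be sketched without any framework. In the toy example below, the training loop knows nothing about pruning and only invokes the callbacks it is given; the callback reports each intermediate value and raises an exception when the trial should stop. Every name here (<code class="docutils literal notranslate"><span class="pre">FakeTrial</span></code>, <code class="docutils literal notranslate"><span class="pre">Pruned</span></code>, <code class="docutils literal notranslate"><span class="pre">make_pruning_callback</span></code>, <code class="docutils literal notranslate"><span class="pre">train</span></code>) is hypothetical and stands in for the real Optuna and XGBoost objects; the fixed-threshold pruning rule is likewise an assumption chosen to keep the example short.</p>

```python
class FakeTrial:
    """Stand-in for a trial object; NOT Optuna's Trial class."""

    def __init__(self, threshold):
        self.threshold = threshold
        self.reported = []

    def report(self, value, step):
        # Record the intermediate value for this step.
        self.reported.append((step, value))

    def should_prune(self):
        # Toy rule: prune once the latest error exceeds a fixed threshold.
        return self.reported[-1][1] > self.threshold


class Pruned(Exception):
    """Hypothetical stand-in for a 'trial pruned' exception."""


def make_pruning_callback(trial):
    """Build a callback that reports to the trial and stops training if asked."""

    def callback(step, error):
        trial.report(error, step)
        if trial.should_prune():
            raise Pruned(f"pruned at step {step}")

    return callback


def train(errors, callbacks):
    """Toy training loop that only knows how to invoke callbacks."""
    for step, error in enumerate(errors):
        for cb in callbacks:
            cb(step, error)
    return errors[-1]


trial = FakeTrial(threshold=0.5)
try:
    train([0.4, 0.45, 0.6, 0.3], [make_pruning_callback(trial)])
except Pruned as exc:
    print(exc)  # pruned at step 2
```

<p>Because the pruning logic lives entirely in the callback, the body of <code class="docutils literal notranslate"><span class="pre">train</span></code> stays unchanged, which is the property the integration modules rely on.</p>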
</div>
</div>


           </div>
           
          </div>
          <footer>
  
    <div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
      
        <a href="sampler.html" class="btn btn-neutral float-right" title="User-Defined Sampler" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
      
      
        <a href="attributes.html" class="btn btn-neutral float-left" title="User Attributes" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
      
    </div>
  

  <hr/>

  <div role="contentinfo">
    <p>
        &copy; Copyright 2018, Optuna Contributors.

    </p>
  </div>
  Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
    <a href="../privacy.html">Privacy Policy</a>.
     


</footer>

        </div>
      </div>

    </section>

  </div>
  


  <script type="text/javascript">
      jQuery(function () {
          SphinxRtdTheme.Navigation.enable(true);
      });
  </script>

  
  
    

    
   

</body>
</html>